!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Routines to calculate EXX within GW
!> \par History
!>      07.2020 separated from mp2.F [F. Stein, code by Jan Wilhelm]
!>      07.2024 determine number of corrected MOs from BSE cutoffs [Maximilian Graml]
!> \author Jan Wilhelm, Frederick Stein
! **************************************************************************************************
MODULE rpa_gw_sigma_x
   USE admm_methods,                    ONLY: admm_mo_merge_ks_matrix
   USE admm_types,                      ONLY: admm_type,&
                                              get_admm_env
   USE bse_util,                        ONLY: determine_cutoff_indices
   USE cp_cfm_basic_linalg,             ONLY: cp_cfm_scale_and_add_fm
   USE cp_cfm_types,                    ONLY: cp_cfm_create,&
                                              cp_cfm_get_info,&
                                              cp_cfm_release,&
                                              cp_cfm_type
   USE cp_control_types,                ONLY: dft_control_type
   USE cp_dbcsr_api,                    ONLY: &
        dbcsr_add, dbcsr_copy, dbcsr_create, dbcsr_desymmetrize, dbcsr_multiply, dbcsr_p_type, &
        dbcsr_release, dbcsr_release_p, dbcsr_set, dbcsr_type, dbcsr_type_antisymmetric, &
        dbcsr_type_symmetric
   USE cp_dbcsr_contrib,                ONLY: dbcsr_get_diag
   USE cp_dbcsr_cp2k_link,              ONLY: cp_dbcsr_alloc_block_from_nbl
   USE cp_dbcsr_operations,             ONLY: copy_dbcsr_to_fm,&
                                              copy_fm_to_dbcsr,&
                                              dbcsr_allocate_matrix_set,&
                                              dbcsr_deallocate_matrix_set
   USE cp_files,                        ONLY: close_file,&
                                              open_file
   USE cp_fm_struct,                    ONLY: cp_fm_struct_type
   USE cp_fm_types,                     ONLY: cp_fm_create,&
                                              cp_fm_get_info,&
                                              cp_fm_release,&
                                              cp_fm_type
   USE hfx_energy_potential,            ONLY: integrate_four_center
   USE hfx_exx,                         ONLY: calc_exx_admm_xc_contributions,&
                                              exx_post_hfx,&
                                              exx_pre_hfx
   USE hfx_ri,                          ONLY: hfx_ri_update_ks
   USE input_constants,                 ONLY: do_admm_basis_projection,&
                                              do_admm_purify_none,&
                                              gw_print_exx,&
                                              gw_read_exx,&
                                              xc_none
   USE input_section_types,             ONLY: section_vals_get,&
                                              section_vals_get_subs_vals,&
                                              section_vals_type,&
                                              section_vals_val_get,&
                                              section_vals_val_set
   USE kinds,                           ONLY: dp
   USE kpoint_methods,                  ONLY: rskp_transform
   USE kpoint_types,                    ONLY: get_kpoint_info,&
                                              kpoint_env_type,&
                                              kpoint_type
   USE machine,                         ONLY: m_walltime
   USE mathconstants,                   ONLY: gaussi,&
                                              z_one,&
                                              z_zero
   USE message_passing,                 ONLY: mp_para_env_type
   USE mp2_integrals,                   ONLY: compute_kpoints
   USE mp2_ri_2c,                       ONLY: trunc_coulomb_for_exchange
   USE mp2_types,                       ONLY: mp2_type
   USE parallel_gemm_api,               ONLY: parallel_gemm
   USE physcon,                         ONLY: evolt
   USE qs_energy_types,                 ONLY: qs_energy_type
   USE qs_environment_types,            ONLY: get_qs_env,&
                                              qs_environment_type
   USE qs_ks_methods,                   ONLY: qs_ks_build_kohn_sham_matrix
   USE qs_ks_types,                     ONLY: qs_ks_env_type
   USE qs_mo_types,                     ONLY: get_mo_set,&
                                              mo_set_type
   USE qs_neighbor_list_types,          ONLY: neighbor_list_set_p_type
   USE qs_rho_types,                    ONLY: qs_rho_get,&
                                              qs_rho_type
   USE rpa_gw,                          ONLY: compute_minus_vxc_kpoints,&
                                              trafo_to_mo_and_kpoints
   USE rpa_gw_kpoints_util,             ONLY: get_bandstruc_and_k_dependent_MOs

!$ USE OMP_LIB, ONLY: omp_get_max_threads, omp_get_thread_num, omp_get_num_threads

#include "./base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'rpa_gw_sigma_x'

   PUBLIC :: compute_vec_Sigma_x_minus_vxc_gw

CONTAINS

! **************************************************************************************************
!> \brief ...
!> \param qs_env ...
!> \param mp2_env ...
!> \param mos_mp2 ...
!> \param energy_ex ...
!> \param energy_xc_admm ...
!> \param t3 ...
!> \param unit_nr ...
!> \par History
!>      04.2015 created
!> \author Jan Wilhelm
! **************************************************************************************************
   SUBROUTINE compute_vec_Sigma_x_minus_vxc_gw(qs_env, mp2_env, mos_mp2, energy_ex, energy_xc_admm, t3, unit_nr)
      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(mp2_type)                                     :: mp2_env
      TYPE(mo_set_type), DIMENSION(:), INTENT(IN)        :: mos_mp2
      REAL(KIND=dp), INTENT(OUT)                         :: energy_ex, energy_xc_admm(2), t3
      INTEGER, INTENT(IN)                                :: unit_nr

      CHARACTER(len=*), PARAMETER :: routineN = 'compute_vec_Sigma_x_minus_vxc_gw'

      CHARACTER(4)                                       :: occ_virt
      CHARACTER(LEN=40)                                  :: line
      INTEGER :: dimen, gw_corr_lev_occ, gw_corr_lev_tot, gw_corr_lev_virt, handle, homo, &
         homo_reduced_bse, homo_startindex_bse, i_img, ikp, irep, ispin, iunit, max_corr_lev_occ, &
         max_corr_lev_virt, myfun, myfun_aux, myfun_prim, n_level_gw, n_level_gw_ref, n_rep_hf, &
         nkp, nkp_Sigma, nmo, nspins, print_exx, virtual_reduced_bse, virtual_startindex_bse
      LOGICAL :: calc_ints, charge_constrain_tmp, do_admm_rpa, do_hfx, do_kpoints_cubic_RPA, &
         do_kpoints_from_Gamma, do_ri_Sigma_x, really_read_line
      REAL(KIND=dp) :: E_GAP_GW, E_HOMO_GW, E_LUMO_GW, eh1, ehfx, eigval_dft, eigval_hf_at_dft, &
         energy_exc, energy_exc1, energy_exc1_aux_fit, energy_exc_aux_fit, energy_total, &
         exx_minus_vxc, hfx_fraction, min_direct_HF_at_DFT_gap, t1, t2, tmp
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: matrix_tmp_2_diag
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: Eigenval_kp_HF_at_DFT, vec_Sigma_x
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: Eigenval_kp, vec_Sigma_x_minus_vxc_gw, &
                                                            vec_Sigma_x_minus_vxc_gw_im
      REAL(KIND=dp), DIMENSION(:), POINTER               :: mo_eigenvalues
      TYPE(admm_type), POINTER                           :: admm_env
      TYPE(cp_fm_type), POINTER                          :: mo_coeff
      TYPE(dbcsr_p_type), ALLOCATABLE, DIMENSION(:)      :: mat_exchange_for_kp_from_gamma
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_ks, matrix_ks_aux_fit, &
                                                            matrix_ks_aux_fit_hfx, rho_ao, &
                                                            rho_ao_aux_fit
      TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER :: matrix_ks_2d, matrix_ks_kp_im, &
         matrix_ks_kp_re, matrix_ks_transl, matrix_sigma_x_minus_vxc, matrix_sigma_x_minus_vxc_im, &
         rho_ao_2d
      TYPE(dbcsr_type)                                   :: matrix_tmp, matrix_tmp_2, mo_coeff_b
      TYPE(dft_control_type), POINTER                    :: dft_control
      TYPE(kpoint_type), POINTER                         :: kpoints, kpoints_Sigma
      TYPE(mp_para_env_type), POINTER                    :: para_env
      TYPE(qs_energy_type), POINTER                      :: energy
      TYPE(qs_ks_env_type), POINTER                      :: ks_env
      TYPE(qs_rho_type), POINTER                         :: rho, rho_aux_fit
      TYPE(section_vals_type), POINTER                   :: hfx_sections, input, xc_section, &
                                                            xc_section_admm_aux, &
                                                            xc_section_admm_prim

      NULLIFY (admm_env, matrix_ks, matrix_ks_aux_fit, rho_ao, matrix_sigma_x_minus_vxc, input, &
               xc_section, xc_section_admm_aux, xc_section_admm_prim, hfx_sections, rho, &
               dft_control, para_env, ks_env, mo_coeff, matrix_sigma_x_minus_vxc_im, matrix_ks_aux_fit_hfx, &
               rho_aux_fit, rho_ao_aux_fit)

      CALL timeset(routineN, handle)

      t1 = m_walltime()

      do_admm_rpa = mp2_env%ri_rpa%do_admm
      do_ri_Sigma_x = mp2_env%ri_g0w0%do_ri_Sigma_x
      do_kpoints_cubic_RPA = qs_env%mp2_env%ri_rpa_im_time%do_im_time_kpoints
      do_kpoints_from_Gamma = qs_env%mp2_env%ri_rpa_im_time%do_kpoints_from_Gamma
      print_exx = mp2_env%ri_g0w0%print_exx

      IF (do_kpoints_cubic_RPA) THEN
         CPASSERT(do_ri_Sigma_x)
      END IF

      IF (do_kpoints_cubic_RPA) THEN

         CALL get_qs_env(qs_env, &
                         admm_env=admm_env, &
                         matrix_ks_kp=matrix_ks_transl, &
                         rho=rho, &
                         input=input, &
                         dft_control=dft_control, &
                         para_env=para_env, &
                         kpoints=kpoints, &
                         ks_env=ks_env, &
                         energy=energy)
         nkp = kpoints%nkp

      ELSE

         CALL get_qs_env(qs_env, &
                         admm_env=admm_env, &
                         matrix_ks=matrix_ks, &
                         rho=rho, &
                         input=input, &
                         dft_control=dft_control, &
                         para_env=para_env, &
                         ks_env=ks_env, &
                         energy=energy)
         nkp = 1
         CALL qs_rho_get(rho, rho_ao=rho_ao)

         IF (do_admm_rpa) THEN
            CALL get_admm_env(admm_env, matrix_ks_aux_fit=matrix_ks_aux_fit, rho_aux_fit=rho_aux_fit, &
                              matrix_ks_aux_fit_hfx=matrix_ks_aux_fit_hfx)
            CALL qs_rho_get(rho_aux_fit, rho_ao=rho_ao_aux_fit)

            ! RPA/GW with ADMM for EXX or the exchange self-energy only implemented for
            ! ADMM_PURIFICATION_METHOD  NONE
            ! METHOD                    BASIS_PROJECTION
            ! in the admm section
            CPASSERT(admm_env%purification_method == do_admm_purify_none)
            CPASSERT(dft_control%admm_control%method == do_admm_basis_projection)
         END IF
      END IF

      nspins = dft_control%nspins

      ! save the KS matrix for later: we will transform matrix_ks
      ! to the T-cell index and then to k-points for the band structure calculation
      IF (do_kpoints_from_Gamma) THEN
         ! not yet there: open shell
         ALLOCATE (qs_env%mp2_env%ri_g0w0%matrix_ks(nspins))
         DO ispin = 1, nspins
            NULLIFY (qs_env%mp2_env%ri_g0w0%matrix_ks(ispin)%matrix)
            ALLOCATE (qs_env%mp2_env%ri_g0w0%matrix_ks(ispin)%matrix)
            CALL dbcsr_create(qs_env%mp2_env%ri_g0w0%matrix_ks(ispin)%matrix, &
                              template=matrix_ks(ispin)%matrix)
            CALL dbcsr_desymmetrize(matrix_ks(ispin)%matrix, &
                                    qs_env%mp2_env%ri_g0w0%matrix_ks(ispin)%matrix)

         END DO
      END IF

      IF (do_kpoints_cubic_RPA) THEN

         CALL allocate_matrix_ks_kp(matrix_ks_transl, matrix_ks_kp_re, matrix_ks_kp_im, kpoints)
         CALL transform_matrix_ks_to_kp(matrix_ks_transl, matrix_ks_kp_re, matrix_ks_kp_im, kpoints)

         DO ispin = 1, nspins
         DO i_img = 1, SIZE(matrix_ks_transl, 2)
            CALL dbcsr_set(matrix_ks_transl(ispin, i_img)%matrix, 0.0_dp)
         END DO
         END DO

      END IF

      ! initialize matrix_sigma_x_minus_vxc
      NULLIFY (matrix_sigma_x_minus_vxc)
      CALL dbcsr_allocate_matrix_set(matrix_sigma_x_minus_vxc, nspins, nkp)
      IF (do_kpoints_cubic_RPA) THEN
         NULLIFY (matrix_sigma_x_minus_vxc_im)
         CALL dbcsr_allocate_matrix_set(matrix_sigma_x_minus_vxc_im, nspins, nkp)
      END IF

      DO ispin = 1, nspins
         DO ikp = 1, nkp

            IF (do_kpoints_cubic_RPA) THEN

               ALLOCATE (matrix_sigma_x_minus_vxc(ispin, ikp)%matrix)
               CALL dbcsr_create(matrix_sigma_x_minus_vxc(ispin, ikp)%matrix, &
                                 template=matrix_ks_kp_re(1, 1)%matrix, &
                                 matrix_type=dbcsr_type_symmetric)

               CALL dbcsr_copy(matrix_sigma_x_minus_vxc(ispin, ikp)%matrix, matrix_ks_kp_re(ispin, ikp)%matrix)
               CALL dbcsr_set(matrix_ks_kp_re(ispin, ikp)%matrix, 0.0_dp)

               ALLOCATE (matrix_sigma_x_minus_vxc_im(ispin, ikp)%matrix)
               CALL dbcsr_create(matrix_sigma_x_minus_vxc_im(ispin, ikp)%matrix, &
                                 template=matrix_ks_kp_im(1, 1)%matrix, &
                                 matrix_type=dbcsr_type_antisymmetric)

               CALL dbcsr_copy(matrix_sigma_x_minus_vxc_im(ispin, ikp)%matrix, matrix_ks_kp_im(ispin, ikp)%matrix)
               CALL dbcsr_set(matrix_ks_kp_im(ispin, ikp)%matrix, 0.0_dp)

            ELSE

               ALLOCATE (matrix_sigma_x_minus_vxc(ispin, ikp)%matrix)
               CALL dbcsr_create(matrix_sigma_x_minus_vxc(ispin, ikp)%matrix, &
                                 template=matrix_ks(1)%matrix)

               CALL dbcsr_copy(matrix_sigma_x_minus_vxc(ispin, ikp)%matrix, matrix_ks(ispin)%matrix)
               CALL dbcsr_set(matrix_ks(ispin)%matrix, 0.0_dp)

            END IF

         END DO
      END DO

      ! set DFT functional to none and hfx_fraction to zero
      hfx_sections => section_vals_get_subs_vals(input, "DFT%XC%HF")
      CALL section_vals_get(hfx_sections, explicit=do_hfx)

      IF (do_hfx) THEN
         hfx_fraction = qs_env%x_data(1, 1)%general_parameter%fraction
         qs_env%x_data(:, :)%general_parameter%fraction = 0.0_dp
      END IF
      xc_section => section_vals_get_subs_vals(input, "DFT%XC")
      CALL section_vals_val_get(xc_section, "XC_FUNCTIONAL%_SECTION_PARAMETERS_", &
                                i_val=myfun)
      CALL section_vals_val_set(xc_section, "XC_FUNCTIONAL%_SECTION_PARAMETERS_", &
                                i_val=xc_none)

      ! in ADMM, also set the XC functional for ADMM correction to none
      ! do not do this if we do ADMM for Sigma_x
      IF (dft_control%do_admm) THEN
         xc_section_admm_aux => section_vals_get_subs_vals(admm_env%xc_section_aux, &
                                                           "XC_FUNCTIONAL")
         CALL section_vals_val_get(xc_section_admm_aux, "_SECTION_PARAMETERS_", &
                                   i_val=myfun_aux)
         CALL section_vals_val_set(xc_section_admm_aux, "_SECTION_PARAMETERS_", &
                                   i_val=xc_none)

         ! the same for the primary basis
         xc_section_admm_prim => section_vals_get_subs_vals(admm_env%xc_section_primary, &
                                                            "XC_FUNCTIONAL")
         CALL section_vals_val_get(xc_section_admm_prim, "_SECTION_PARAMETERS_", &
                                   i_val=myfun_prim)
         CALL section_vals_val_set(xc_section_admm_prim, "_SECTION_PARAMETERS_", &
                                   i_val=xc_none)

         ! for ADMMQ/S, set the charge_constrain to false (otherwise wrong results)
         charge_constrain_tmp = .FALSE.
         IF (admm_env%charge_constrain) THEN
            admm_env%charge_constrain = .FALSE.
            charge_constrain_tmp = .TRUE.
         END IF

      END IF

      ! if we do ADMM for Sigma_x, we write the ADMM correction into matrix_ks_aux_fit
      ! and therefore we should set it to zero
      IF (do_admm_rpa) THEN
         DO ispin = 1, nspins
            CALL dbcsr_set(matrix_ks_aux_fit(ispin)%matrix, 0.0_dp)
         END DO
      END IF

      IF (.NOT. mp2_env%ri_g0w0%update_xc_energy) THEN
         energy_total = energy%total
         energy_exc = energy%exc
         energy_exc1 = energy%exc1
         energy_exc_aux_fit = energy%ex
         energy_exc1_aux_fit = energy%exc_aux_fit
         energy_ex = energy%exc1_aux_fit
      END IF

      ! Remove the Exchange-correlation energy contributions from the total energy
      energy%total = energy%total - (energy%exc + energy%exc1 + energy%ex + &
                                     energy%exc_aux_fit + energy%exc1_aux_fit)

      ! calculate KS-matrix without XC and without HF
      CALL qs_ks_build_kohn_sham_matrix(qs_env=qs_env, calculate_forces=.FALSE., &
                                        just_energy=.FALSE.)

      IF (.NOT. mp2_env%ri_g0w0%update_xc_energy) THEN
         energy%exc = energy_exc
         energy%exc1 = energy_exc1
         energy%exc_aux_fit = energy_ex
         energy%exc1_aux_fit = energy_exc_aux_fit
         energy%ex = energy_exc1_aux_fit
         energy%total = energy_total
      END IF

      ! set the DFT functional and HF fraction back
      CALL section_vals_val_set(xc_section, "XC_FUNCTIONAL%_SECTION_PARAMETERS_", &
                                i_val=myfun)
      IF (do_hfx) THEN
         qs_env%x_data(:, :)%general_parameter%fraction = hfx_fraction
      END IF

      IF (dft_control%do_admm) THEN
         xc_section_admm_aux => section_vals_get_subs_vals(admm_env%xc_section_aux, &
                                                           "XC_FUNCTIONAL")
         xc_section_admm_prim => section_vals_get_subs_vals(admm_env%xc_section_primary, &
                                                            "XC_FUNCTIONAL")

         CALL section_vals_val_set(xc_section_admm_aux, "_SECTION_PARAMETERS_", &
                                   i_val=myfun_aux)
         CALL section_vals_val_set(xc_section_admm_prim, "_SECTION_PARAMETERS_", &
                                   i_val=myfun_prim)
         IF (charge_constrain_tmp) THEN
            admm_env%charge_constrain = .TRUE.
         END IF
      END IF

      IF (do_kpoints_cubic_RPA) THEN
         CALL transform_matrix_ks_to_kp(matrix_ks_transl, matrix_ks_kp_re, matrix_ks_kp_im, kpoints)
      END IF

      ! remove the single-particle part (kin. En + Hartree pot) and change the sign
      DO ispin = 1, nspins
         IF (do_kpoints_cubic_RPA) THEN
            DO ikp = 1, nkp
               CALL dbcsr_add(matrix_sigma_x_minus_vxc(ispin, ikp)%matrix, matrix_ks_kp_re(ispin, ikp)%matrix, -1.0_dp, 1.0_dp)
               CALL dbcsr_add(matrix_sigma_x_minus_vxc_im(ispin, ikp)%matrix, matrix_ks_kp_im(ispin, ikp)%matrix, -1.0_dp, 1.0_dp)
            END DO
         ELSE
            CALL dbcsr_add(matrix_sigma_x_minus_vxc(ispin, 1)%matrix, matrix_ks(ispin)%matrix, -1.0_dp, 1.0_dp)
         END IF
      END DO

      IF (do_kpoints_cubic_RPA) THEN

         CALL transform_sigma_x_minus_vxc_to_MO_basis(kpoints, matrix_sigma_x_minus_vxc, &
                                                      matrix_sigma_x_minus_vxc_im, &
                                                      vec_Sigma_x_minus_vxc_gw, &
                                                      vec_Sigma_x_minus_vxc_gw_im, &
                                                      para_env, nmo, mp2_env)

      ELSE

         DO ispin = 1, nspins
            CALL dbcsr_set(matrix_ks(ispin)%matrix, 0.0_dp)
            IF (do_admm_rpa) THEN
               CALL dbcsr_set(matrix_ks_aux_fit(ispin)%matrix, 0.0_dp)
            END IF
         END DO

         hfx_sections => section_vals_get_subs_vals(input, "DFT%XC%WF_CORRELATION%RI_RPA%HF")

         CALL section_vals_get(hfx_sections, n_repetition=n_rep_hf)

         ! in most cases, we calculate the exchange self-energy here. But if we do only RI for
         ! the exchange self-energy, we do not calculate exchange here
         ehfx = 0.0_dp
         IF (.NOT. do_ri_Sigma_x) THEN

            CALL exx_pre_hfx(hfx_sections, qs_env%mp2_env%ri_rpa%x_data, qs_env%mp2_env%ri_rpa%reuse_hfx)
            calc_ints = .NOT. qs_env%mp2_env%ri_rpa%reuse_hfx

            ! add here HFX (=Sigma_exchange) to matrix_sigma_x_minus_vxc
            DO irep = 1, n_rep_hf
               IF (do_admm_rpa) THEN
                  matrix_ks_2d(1:nspins, 1:1) => matrix_ks_aux_fit(1:nspins)
                  rho_ao_2d(1:nspins, 1:1) => rho_ao_aux_fit(1:nspins)
               ELSE
                  matrix_ks_2d(1:nspins, 1:1) => matrix_ks(1:nspins)
                  rho_ao_2d(1:nspins, 1:1) => rho_ao(1:nspins)
               END IF

               IF (qs_env%mp2_env%ri_rpa%x_data(irep, 1)%do_hfx_ri) THEN
                  CALL hfx_ri_update_ks(qs_env, qs_env%mp2_env%ri_rpa%x_data(irep, 1)%ri_data, matrix_ks_2d, ehfx, &
                                        rho_ao=rho_ao_2d, geometry_did_change=calc_ints, nspins=nspins, &
                                        hf_fraction=qs_env%mp2_env%ri_rpa%x_data(irep, 1)%general_parameter%fraction)

                  IF (do_admm_rpa) THEN
                     !for ADMMS, we need the exchange matrix k(d) for both spins
                     DO ispin = 1, nspins
                        CALL dbcsr_copy(matrix_ks_aux_fit_hfx(ispin)%matrix, matrix_ks_2d(ispin, 1)%matrix, &
                                        name="HF exch. part of matrix_ks_aux_fit for ADMMS")
                     END DO
                  END IF
               ELSE
                  CALL integrate_four_center(qs_env, qs_env%mp2_env%ri_rpa%x_data, matrix_ks_2d, eh1, &
                                             rho_ao_2d, hfx_sections, &
                                             para_env, calc_ints, irep, .TRUE., &
                                             ispin=1)
                  ehfx = ehfx + eh1
               END IF
            END DO

            !ADMM XC correction
            IF (do_admm_rpa) THEN
               CALL calc_exx_admm_xc_contributions(qs_env=qs_env, &
                                                   matrix_prim=matrix_ks, &
                                                   matrix_aux=matrix_ks_aux_fit, &
                                                   x_data=qs_env%mp2_env%ri_rpa%x_data, &
                                                   exc=energy_xc_admm(1), &
                                                   exc_aux_fit=energy_xc_admm(2), &
                                                   calc_forces=.FALSE., &
                                                   use_virial=.FALSE.)
            END IF

            IF (do_kpoints_from_Gamma .AND. print_exx == gw_print_exx) THEN
               ALLOCATE (mat_exchange_for_kp_from_gamma(1))

               DO ispin = 1, 1
                  NULLIFY (mat_exchange_for_kp_from_gamma(ispin)%matrix)
                  ALLOCATE (mat_exchange_for_kp_from_gamma(ispin)%matrix)
                  CALL dbcsr_create(mat_exchange_for_kp_from_gamma(ispin)%matrix, template=matrix_ks(ispin)%matrix)
                  CALL dbcsr_desymmetrize(matrix_ks(ispin)%matrix, mat_exchange_for_kp_from_gamma(ispin)%matrix)
               END DO

            END IF

            CALL exx_post_hfx(qs_env, qs_env%mp2_env%ri_rpa%x_data, qs_env%mp2_env%ri_rpa%reuse_hfx)
         END IF

         energy_ex = ehfx

         ! transform Fock-Matrix (calculated in integrate_four_center, written in matrix_ks_aux_fit in case
         ! of ADMM) from ADMM basis to primary basis
         IF (do_admm_rpa) THEN
            CALL admm_mo_merge_ks_matrix(qs_env)
         END IF

         DO ispin = 1, nspins
            CALL dbcsr_add(matrix_sigma_x_minus_vxc(ispin, 1)%matrix, matrix_ks(ispin)%matrix, 1.0_dp, 1.0_dp)
         END DO

         ! save matrix_sigma_x_minus_vxc for later: for example, we will transform matrix_sigma_x_minus_vxc
         ! to the T-cell index and then to k-points for the band structure calculation
         IF (do_kpoints_from_Gamma) THEN
            ! not yet there: open shell
            ALLOCATE (qs_env%mp2_env%ri_g0w0%matrix_sigma_x_minus_vxc(nspins))
            DO ispin = 1, nspins
               NULLIFY (qs_env%mp2_env%ri_g0w0%matrix_sigma_x_minus_vxc(ispin)%matrix)
               ALLOCATE (qs_env%mp2_env%ri_g0w0%matrix_sigma_x_minus_vxc(ispin)%matrix)
               CALL dbcsr_create(qs_env%mp2_env%ri_g0w0%matrix_sigma_x_minus_vxc(ispin)%matrix, &
                                 template=matrix_ks(ispin)%matrix)

               CALL dbcsr_desymmetrize(matrix_sigma_x_minus_vxc(ispin, 1)%matrix, &
                                       qs_env%mp2_env%ri_g0w0%matrix_sigma_x_minus_vxc(ispin)%matrix)

            END DO
         END IF

         CALL dbcsr_desymmetrize(matrix_ks(1)%matrix, mo_coeff_b)
         CALL dbcsr_set(mo_coeff_b, 0.0_dp)

         ! Transform matrix_sigma_x_minus_vxc to MO basis
         DO ispin = 1, nspins

            CALL get_mo_set(mo_set=mos_mp2(ispin), &
                            mo_coeff=mo_coeff, &
                            eigenvalues=mo_eigenvalues, &
                            nmo=nmo, &
                            homo=homo, &
                            nao=dimen)

            IF (ispin == 1) THEN

               ALLOCATE (vec_Sigma_x_minus_vxc_gw(nmo, nspins, nkp))
               vec_Sigma_x_minus_vxc_gw = 0.0_dp

               ALLOCATE (matrix_tmp_2_diag(dimen))
            END IF

            CALL dbcsr_set(mo_coeff_b, 0.0_dp)
            CALL copy_fm_to_dbcsr(mo_coeff, mo_coeff_b, keep_sparsity=.FALSE.)

            ! initialize matrix_tmp and matrix_tmp_2
            IF (ispin == 1) THEN
               CALL dbcsr_create(matrix_tmp, template=mo_coeff_b)
               CALL dbcsr_copy(matrix_tmp, mo_coeff_b)
               CALL dbcsr_set(matrix_tmp, 0.0_dp)

               CALL dbcsr_create(matrix_tmp_2, template=mo_coeff_b)
               CALL dbcsr_copy(matrix_tmp_2, mo_coeff_b)
               CALL dbcsr_set(matrix_tmp_2, 0.0_dp)
            END IF

            gw_corr_lev_occ = mp2_env%ri_g0w0%corr_mos_occ
            gw_corr_lev_virt = mp2_env%ri_g0w0%corr_mos_virt

            ! If SVD is used to invert overlap matrix (for CHOLESKY OFF), some MOs are removed
            ! Therefore, setting the number of gw_corr_lev_virt simply to dimen - homo leads to index problems
            ! Instead, we take into account the removed MOs
            max_corr_lev_occ = homo
            max_corr_lev_virt = nmo - homo

            ! If BSE is invoked, manipulate corrected MO number
            IF (mp2_env%bse%do_bse) THEN
               ! Logic: If cutoff is negative, all MOs are included in BSE, i.e. we need to correct them all
               !        If cutoff is positive, we can reduce the number of MOs to be corrected and force gw_corr_lev_...
               !        to a sufficiently large number by setting it to -2 and read indices afterwards
               ! Handling for occupied levels
               IF (mp2_env%bse%bse_cutoff_occ < 0) THEN
                  gw_corr_lev_occ = -1
               ELSE
                  IF (gw_corr_lev_occ > 0) THEN
                     gw_corr_lev_occ = -2
                  END IF
               END IF
               ! Handling for virtual levels
               IF (mp2_env%bse%bse_cutoff_empty < 0) THEN
                  gw_corr_lev_virt = -1
               ELSE
                  IF (gw_corr_lev_virt > 0) THEN
                     gw_corr_lev_virt = -2
                  END IF
               END IF

               ! Obtain indices from DFT if gw_corr... are set to -2
               CALL determine_cutoff_indices(mo_eigenvalues, &
                                             homo, max_corr_lev_virt, &
                                             homo_reduced_bse, virtual_reduced_bse, &
                                             homo_startindex_bse, virtual_startindex_bse, &
                                             mp2_env)
               IF (gw_corr_lev_occ == -2) THEN
                  CPWARN("BSE cutoff overwrites user input for CORR_MOS_OCC")
                  gw_corr_lev_occ = homo_reduced_bse
               END IF
               IF (gw_corr_lev_virt == -2) THEN
                  CPWARN("BSE cutoff overwrites user input for CORR_MOS_VIRT")
                  gw_corr_lev_virt = virtual_reduced_bse
               END IF
            END IF

            ! if requested number of occ/virt levels for correction either exceed the number of
            ! occ/virt levels or the requested number is negative, default to correct all
            ! occ/virt level energies
            IF (gw_corr_lev_occ > homo .OR. gw_corr_lev_occ < 0) gw_corr_lev_occ = max_corr_lev_occ
            IF (gw_corr_lev_virt > max_corr_lev_virt .OR. gw_corr_lev_virt < 0) gw_corr_lev_virt = max_corr_lev_virt
            IF (ispin == 1) THEN
               mp2_env%ri_g0w0%corr_mos_occ = gw_corr_lev_occ
               mp2_env%ri_g0w0%corr_mos_virt = gw_corr_lev_virt
            ELSE IF (ispin == 2) THEN
               ! ensure that the total number of corrected MOs is the same for alpha and beta, important
               ! for parallelization
               IF (mp2_env%ri_g0w0%corr_mos_occ + mp2_env%ri_g0w0%corr_mos_virt /= &
                   gw_corr_lev_occ + gw_corr_lev_virt) THEN
                  gw_corr_lev_virt = mp2_env%ri_g0w0%corr_mos_occ + mp2_env%ri_g0w0%corr_mos_virt - gw_corr_lev_occ
               END IF
               mp2_env%ri_g0w0%corr_mos_occ_beta = gw_corr_lev_occ
               mp2_env%ri_g0w0%corr_mos_virt_beta = gw_corr_lev_virt

            END IF

            CALL dbcsr_multiply('N', 'N', 1.0_dp, matrix_sigma_x_minus_vxc(ispin, 1)%matrix, &
                                mo_coeff_b, 0.0_dp, matrix_tmp, first_column=homo + 1 - gw_corr_lev_occ, &
                                last_column=homo + gw_corr_lev_virt)

            CALL dbcsr_multiply('T', 'N', 1.0_dp, mo_coeff_b, &
                                matrix_tmp, 0.0_dp, matrix_tmp_2, first_row=homo + 1 - gw_corr_lev_occ, &
                                last_row=homo + gw_corr_lev_virt)

            CALL dbcsr_get_diag(matrix_tmp_2, matrix_tmp_2_diag)
            vec_Sigma_x_minus_vxc_gw(1:nmo, ispin, 1) = matrix_tmp_2_diag(1:nmo)

            CALL dbcsr_set(matrix_tmp, 0.0_dp)
            CALL dbcsr_set(matrix_tmp_2, 0.0_dp)

         END DO

         CALL para_env%sum(vec_Sigma_x_minus_vxc_gw)

      END IF

      CALL dbcsr_release(mo_coeff_b)
      CALL dbcsr_release(matrix_tmp)
      CALL dbcsr_release(matrix_tmp_2)
      IF (do_kpoints_cubic_RPA) THEN
         CALL dbcsr_deallocate_matrix_set(matrix_ks_kp_re)
         CALL dbcsr_deallocate_matrix_set(matrix_ks_kp_im)
      END IF

      DO ispin = 1, nspins
         DO ikp = 1, nkp
            CALL dbcsr_release_p(matrix_sigma_x_minus_vxc(ispin, ikp)%matrix)
            IF (do_kpoints_cubic_RPA) THEN
               CALL dbcsr_release_p(matrix_sigma_x_minus_vxc_im(ispin, ikp)%matrix)
            END IF
         END DO
      END DO

      ALLOCATE (mp2_env%ri_g0w0%vec_Sigma_x_minus_vxc_gw(nmo, nspins, nkp))

      IF (print_exx == gw_print_exx) THEN

         IF (do_kpoints_from_Gamma) THEN

            gw_corr_lev_tot = gw_corr_lev_occ + gw_corr_lev_virt

            CALL get_qs_env(qs_env=qs_env, &
                            kpoints=kpoints)

            CALL trunc_coulomb_for_exchange(qs_env)

            CALL compute_kpoints(qs_env, kpoints, unit_nr)

            ALLOCATE (Eigenval_kp(nmo, 1, nspins))

            CALL get_bandstruc_and_k_dependent_MOs(qs_env, Eigenval_kp)

            CALL compute_minus_vxc_kpoints(qs_env)

            nkp_Sigma = SIZE(Eigenval_kp, 2)

            ALLOCATE (vec_Sigma_x(nmo, nkp_Sigma))
            vec_Sigma_x(:, :) = 0.0_dp

            CALL trafo_to_mo_and_kpoints(qs_env, &
                                         mat_exchange_for_kp_from_gamma(1)%matrix, &
                                         vec_Sigma_x(homo - gw_corr_lev_occ + 1:homo + gw_corr_lev_virt, :), &
                                         homo, gw_corr_lev_occ, gw_corr_lev_virt, 1)

            CALL dbcsr_release(mat_exchange_for_kp_from_gamma(1)%matrix)
            DEALLOCATE (mat_exchange_for_kp_from_gamma(1)%matrix)
            DEALLOCATE (mat_exchange_for_kp_from_gamma)

            DEALLOCATE (vec_Sigma_x_minus_vxc_gw)

            ALLOCATE (vec_Sigma_x_minus_vxc_gw(nmo, nspins, nkp_Sigma))

            vec_Sigma_x_minus_vxc_gw(:, 1, :) = vec_Sigma_x(:, :) + &
                                                qs_env%mp2_env%ri_g0w0%vec_Sigma_x_minus_vxc_gw(:, 1, :)

            kpoints_Sigma => qs_env%mp2_env%ri_rpa_im_time%kpoints_Sigma

         ELSE

            nkp_Sigma = 1

         END IF

         IF (unit_nr > 0) THEN

            ALLOCATE (Eigenval_kp_HF_at_DFT(nmo, nkp_Sigma))
            Eigenval_kp_HF_at_DFT(:, :) = Eigenval_kp(:, :, 1) + vec_Sigma_x_minus_vxc_gw(:, 1, :)

            min_direct_HF_at_DFT_gap = 100.0_dp

            WRITE (unit_nr, '(T3,A)') ''
            WRITE (unit_nr, '(T3,A)') 'Exchange energies'
            WRITE (unit_nr, '(T3,A)') '-----------------'
            WRITE (unit_nr, '(T3,A)') ''
            WRITE (unit_nr, '(T6,2A)') 'MO                        e_n^DFT          Sigma_x-vxc           e_n^HF@DFT'
            DO ikp = 1, nkp_Sigma
               IF (nkp_Sigma > 1) THEN
                  WRITE (unit_nr, '(T3,A)') ''
                  WRITE (unit_nr, '(T3,A7,I3,A3,I3,A8,3F7.3,A12,3F7.3)') 'Kpoint ', ikp, '  /', nkp_Sigma, &
                     '   xkp =', kpoints_Sigma%xkp(1, ikp), kpoints_Sigma%xkp(2, ikp), &
                     kpoints_Sigma%xkp(3, ikp), '  and  xkp =', -kpoints_Sigma%xkp(1, ikp), &
                     -kpoints_Sigma%xkp(2, ikp), -kpoints_Sigma%xkp(3, ikp)
                  WRITE (unit_nr, '(T3,A)') ''
               END IF
               DO n_level_gw = 1, gw_corr_lev_occ + gw_corr_lev_virt

                  n_level_gw_ref = n_level_gw + homo - gw_corr_lev_occ
                  IF (n_level_gw <= gw_corr_lev_occ) THEN
                     occ_virt = 'occ'
                  ELSE
                     occ_virt = 'vir'
                  END IF

                  eigval_dft = Eigenval_kp(n_level_gw_ref, ikp, 1)*evolt
                  exx_minus_vxc = REAL(vec_Sigma_x_minus_vxc_gw(n_level_gw_ref, 1, ikp)*evolt, kind=dp)
                  eigval_hf_at_dft = Eigenval_kp_HF_at_DFT(n_level_gw_ref, ikp)*evolt

                  WRITE (unit_nr, '(T4,I4,3A,3F21.3,3F21.3,3F21.3)') &
                     n_level_gw_ref, ' ( ', occ_virt, ')  ', eigval_dft, exx_minus_vxc, eigval_hf_at_dft

               END DO
               E_HOMO_GW = MAXVAL(Eigenval_kp_HF_at_DFT(homo - gw_corr_lev_occ + 1:homo, ikp))
               E_LUMO_GW = MINVAL(Eigenval_kp_HF_at_DFT(homo + 1:homo + gw_corr_lev_virt, ikp))
               E_GAP_GW = E_LUMO_GW - E_HOMO_GW
               IF (E_GAP_GW < min_direct_HF_at_DFT_gap) min_direct_HF_at_DFT_gap = E_GAP_GW
               WRITE (unit_nr, '(T3,A)') ''
               WRITE (unit_nr, '(T3,A,F53.2)') 'HF@DFT HOMO-LUMO gap (eV)', E_GAP_GW*evolt
               WRITE (unit_nr, '(T3,A)') ''
            END DO

            WRITE (unit_nr, '(T3,A)') ''
            WRITE (unit_nr, '(T3,A)') ''
            WRITE (unit_nr, '(T3,A,F63.3)') 'HF@DFT direct bandgap (eV)', min_direct_HF_at_DFT_gap*evolt

            WRITE (unit_nr, '(T3,A)') ''
            WRITE (unit_nr, '(T3,A)') 'End of exchange energies'
            WRITE (unit_nr, '(T3,A)') '------------------------'
            WRITE (unit_nr, '(T3,A)') ''

            CPABORT('Stop after printing exchange energies.')

         ELSE
            CALL para_env%sync()
         END IF

      END IF

      IF (print_exx == gw_read_exx) THEN

         CALL open_file(unit_number=iunit, file_name="exx.out")

         really_read_line = .FALSE.

         DO WHILE (.TRUE.)

            READ (iunit, '(A)') line

            IF (line == "  End of exchange energies              ") EXIT

            IF (really_read_line) THEN

               READ (line(1:7), *) n_level_gw_ref
               READ (line(17:40), *) tmp

               DO ikp = 1, SIZE(vec_Sigma_x_minus_vxc_gw, 3)
                  vec_Sigma_x_minus_vxc_gw(n_level_gw_ref, 1, ikp) = tmp/evolt
               END DO

            END IF

            IF (line == "     MO                    Sigma_x-vxc  ") really_read_line = .TRUE.

         END DO

         CALL close_file(iunit)

      END IF

      ! store vec_Sigma_x_minus_vxc_gw in the mp2_environment
      mp2_env%ri_g0w0%vec_Sigma_x_minus_vxc_gw(:, :, :) = vec_Sigma_x_minus_vxc_gw(:, :, :)

      ! clean up
      DEALLOCATE (matrix_sigma_x_minus_vxc, vec_Sigma_x_minus_vxc_gw)
      IF (do_kpoints_cubic_RPA) THEN
         DEALLOCATE (matrix_sigma_x_minus_vxc_im)
      END IF

      t2 = m_walltime()

      t3 = t2 - t1

      CALL timestop(handle)

   END SUBROUTINE compute_vec_Sigma_x_minus_vxc_gw

! **************************************************************************************************
!> \brief Transforms the k-point dependent matrices Sigma_x - v_xc from the AO basis to the MO basis
!>        (V_n(k) = C^H(k) V(k) C(k)) and stores the real and imaginary part of the diagonal for
!>        every MO, spin and k-point. Afterwards, the numbers of corrected occupied/virtual levels
!>        stored in mp2_env%ri_g0w0 are clipped to the levels that are actually available.
!> \param kpoints k-point environment holding the k-dependent MO coefficients
!> \param matrix_sigma_x_minus_vxc real part of Sigma_x - v_xc in the AO basis, indexed (spin, k-point)
!> \param matrix_sigma_x_minus_vxc_im imaginary part of Sigma_x - v_xc in the AO basis,
!>        indexed (spin, k-point)
!> \param vec_Sigma_x_minus_vxc_gw on output: real part of the MO diagonal, shape (nmo, nspins, nkp)
!> \param vec_Sigma_x_minus_vxc_gw_im on output: imaginary part of the MO diagonal,
!>        shape (nmo, nspins, nkp)
!> \param para_env parallel environment used to sum the locally owned diagonal elements
!> \param nmo on output: number of MOs of the first k-point and first spin
!> \param mp2_env MP2 environment; ri_g0w0%corr_mos_occ/_virt (and the _beta variants for
!>        open-shell runs) are updated
! **************************************************************************************************
   SUBROUTINE transform_sigma_x_minus_vxc_to_MO_basis(kpoints, matrix_sigma_x_minus_vxc, &
                                                      matrix_sigma_x_minus_vxc_im, vec_Sigma_x_minus_vxc_gw, &
                                                      vec_Sigma_x_minus_vxc_gw_im, para_env, nmo, mp2_env)

      TYPE(kpoint_type), POINTER                         :: kpoints
      TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: matrix_sigma_x_minus_vxc, &
                                                            matrix_sigma_x_minus_vxc_im
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :, :)     :: vec_Sigma_x_minus_vxc_gw, &
                                                            vec_Sigma_x_minus_vxc_gw_im
      TYPE(mp_para_env_type), INTENT(IN)                 :: para_env
      INTEGER                                            :: nmo
      TYPE(mp2_type)                                     :: mp2_env

      CHARACTER(LEN=*), PARAMETER :: routineN = 'transform_sigma_x_minus_vxc_to_MO_basis'

      INTEGER :: gw_corr_lev_occ, gw_corr_lev_virt, handle, homo, i_global, iiB, ikp, ispin, &
         j_global, jjB, max_corr_lev_occ, max_corr_lev_virt, ncol_local, nkp, nrow_local, nspins
      INTEGER, DIMENSION(2)                              :: kp_range
      INTEGER, DIMENSION(:), POINTER                     :: col_indices, row_indices
      REAL(KIND=dp)                                      :: imval, reval
      TYPE(cp_cfm_type)                                  :: cfm_mos, cfm_sigma_x_minus_vxc, &
                                                            cfm_sigma_x_minus_vxc_mo_basis, cfm_tmp
      TYPE(cp_fm_struct_type), POINTER                   :: matrix_struct
      TYPE(cp_fm_type)                                   :: fwork_im, fwork_re
      TYPE(kpoint_env_type), POINTER                     :: kp
      TYPE(mo_set_type), POINTER                         :: mo_set, mo_set_im, mo_set_re

      CALL timeset(routineN, handle)

      mo_set => kpoints%kp_env(1)%kpoint_env%mos(1, 1)
      CALL get_mo_set(mo_set, nmo=nmo)

      nspins = SIZE(matrix_sigma_x_minus_vxc, 1)
      CALL get_kpoint_info(kpoints, nkp=nkp, kp_range=kp_range)

      ! every process must hold all k-points, because kp_env(ikp) is accessed for ikp = 1..nkp below.
      ! If this CPASSERT is wrong, please make sure that the kpoint group size PARALLEL_GROUP_SIZE
      ! in the kpoint environment &DFT &KPOINTS is -1
      CPASSERT(kp_range(1) == 1 .AND. kp_range(2) == nkp)

      ! zero-initialized: each process only fills the diagonal elements it owns, the global
      ! vectors are assembled by the para_env%sum calls after the k-point loop
      ALLOCATE (vec_Sigma_x_minus_vxc_gw(nmo, nspins, nkp))
      vec_Sigma_x_minus_vxc_gw = 0.0_dp

      ALLOCATE (vec_Sigma_x_minus_vxc_gw_im(nmo, nspins, nkp))
      vec_Sigma_x_minus_vxc_gw_im = 0.0_dp

      ! all work matrices share the distribution of the MO coefficients
      ! NOTE(review): the gemms below use nmo for all dimensions, which presumably requires nmo == nao
      CALL cp_fm_get_info(mo_set%mo_coeff, matrix_struct=matrix_struct)
      CALL cp_fm_create(fwork_re, matrix_struct)
      CALL cp_fm_create(fwork_im, matrix_struct)
      CALL cp_cfm_create(cfm_mos, matrix_struct)
      CALL cp_cfm_create(cfm_sigma_x_minus_vxc, matrix_struct)
      CALL cp_cfm_create(cfm_sigma_x_minus_vxc_mo_basis, matrix_struct)
      CALL cp_cfm_create(cfm_tmp, matrix_struct)

      CALL cp_cfm_get_info(matrix=cfm_sigma_x_minus_vxc_mo_basis, &
                           nrow_local=nrow_local, &
                           ncol_local=ncol_local, &
                           row_indices=row_indices, &
                           col_indices=col_indices)

      ! Transform matrix_sigma_x_minus_vxc to MO basis
      DO ikp = 1, nkp

         kp => kpoints%kp_env(ikp)%kpoint_env

         DO ispin = 1, nspins

            ! complex AO matrix V(k) = Re + i*Im
            CALL copy_dbcsr_to_fm(matrix_sigma_x_minus_vxc(ispin, ikp)%matrix, fwork_re)
            CALL copy_dbcsr_to_fm(matrix_sigma_x_minus_vxc_im(ispin, ikp)%matrix, fwork_im)

            CALL cp_cfm_scale_and_add_fm(z_zero, cfm_sigma_x_minus_vxc, z_one, fwork_re)
            CALL cp_cfm_scale_and_add_fm(z_one, cfm_sigma_x_minus_vxc, gaussi, fwork_im)

            ! kp%mos is indexed (real/imag part, spin): real part (1) and imag. part (2)
            mo_set_re => kp%mos(1, ispin)
            mo_set_im => kp%mos(2, ispin)

            CALL cp_cfm_scale_and_add_fm(z_zero, cfm_mos, z_one, mo_set_re%mo_coeff)
            CALL cp_cfm_scale_and_add_fm(z_one, cfm_mos, gaussi, mo_set_im%mo_coeff)

            ! tmp = V(k)*C(k)
            CALL parallel_gemm('N', 'N', nmo, nmo, nmo, z_one, cfm_sigma_x_minus_vxc, &
                               cfm_mos, z_zero, cfm_tmp)

            ! V_n(k) = C^H(k)*tmp
            CALL parallel_gemm('C', 'N', nmo, nmo, nmo, z_one, cfm_mos, cfm_tmp, &
                               z_zero, cfm_sigma_x_minus_vxc_mo_basis)

            ! pick the locally owned diagonal elements
            DO jjB = 1, ncol_local

               j_global = col_indices(jjB)

               DO iiB = 1, nrow_local

                  i_global = row_indices(iiB)

                  IF (j_global == i_global .AND. i_global <= nmo) THEN

                     reval = REAL(cfm_sigma_x_minus_vxc_mo_basis%local_data(iiB, jjB), kind=dp)
                     imval = AIMAG(cfm_sigma_x_minus_vxc_mo_basis%local_data(iiB, jjB))

                     vec_Sigma_x_minus_vxc_gw(i_global, ispin, ikp) = reval
                     vec_Sigma_x_minus_vxc_gw_im(i_global, ispin, ikp) = imval

                  END IF

               END DO

            END DO

         END DO

      END DO

      CALL para_env%sum(vec_Sigma_x_minus_vxc_gw)
      CALL para_env%sum(vec_Sigma_x_minus_vxc_gw_im)

      ! for k-points, also clip too large (or negative) gw_corr_lev_occ and gw_corr_lev_virt
      DO ispin = 1, nspins
         ! homo of the current spin; the mo_set must be taken from the real part (first index)
         ! of spin ispin (second index), consistent with kp%mos(1, ispin) used above
         CALL get_mo_set(mo_set=kpoints%kp_env(1)%kpoint_env%mos(1, ispin), homo=homo)
         ! If SVD is used to invert overlap matrix (for CHOLESKY OFF), some MOs are removed
         ! Therefore, setting the number of gw_corr_lev_virt simply to nao - homo leads to index problems
         ! Instead, we take into account the removed MOs
         max_corr_lev_occ = homo
         max_corr_lev_virt = nmo - homo

         gw_corr_lev_occ = mp2_env%ri_g0w0%corr_mos_occ
         gw_corr_lev_virt = mp2_env%ri_g0w0%corr_mos_virt
         ! if corrected occ/virt levels exceed the number of occ/virt levels or are negative,
         ! correct all occ/virt level energies
         IF (gw_corr_lev_occ > homo .OR. gw_corr_lev_occ < 0) gw_corr_lev_occ = max_corr_lev_occ
         IF (gw_corr_lev_virt > max_corr_lev_virt .OR. gw_corr_lev_virt < 0) gw_corr_lev_virt = max_corr_lev_virt
         IF (ispin == 1) THEN
            mp2_env%ri_g0w0%corr_mos_occ = gw_corr_lev_occ
            mp2_env%ri_g0w0%corr_mos_virt = gw_corr_lev_virt
         ELSE IF (ispin == 2) THEN
            ! ensure that the total number of corrected MOs is the same for alpha and beta, important
            ! for parallelization
            IF (mp2_env%ri_g0w0%corr_mos_occ + mp2_env%ri_g0w0%corr_mos_virt /= &
                gw_corr_lev_occ + gw_corr_lev_virt) THEN
               gw_corr_lev_virt = mp2_env%ri_g0w0%corr_mos_occ + mp2_env%ri_g0w0%corr_mos_virt - gw_corr_lev_occ
            END IF
            mp2_env%ri_g0w0%corr_mos_occ_beta = gw_corr_lev_occ
            mp2_env%ri_g0w0%corr_mos_virt_beta = gw_corr_lev_virt
         END IF
      END DO

      CALL cp_fm_release(fwork_re)
      CALL cp_fm_release(fwork_im)
      CALL cp_cfm_release(cfm_mos)
      CALL cp_cfm_release(cfm_sigma_x_minus_vxc)
      CALL cp_cfm_release(cfm_sigma_x_minus_vxc_mo_basis)
      CALL cp_cfm_release(cfm_tmp)

      CALL timestop(handle)

   END SUBROUTINE transform_sigma_x_minus_vxc_to_MO_basis

! **************************************************************************************************
!> \brief Builds the k-point Kohn-Sham matrices from the real-space ones. For every k-point and
!>        spin, the target matrices are cleared and then filled by rskp_transform.
!> \param matrix_ks_transl real-space Kohn-Sham matrices; first index is the spin
!>        (second index presumably the cell image - TODO confirm)
!> \param matrix_ks_kp_re real part of the k-point matrices, indexed (spin, k-point); must be allocated
!> \param matrix_ks_kp_im imaginary part of the k-point matrices, indexed (spin, k-point); must be allocated
!> \param kpoints k-point environment providing the k-points, cell_to_index mapping and neighbor list
! **************************************************************************************************
   SUBROUTINE transform_matrix_ks_to_kp(matrix_ks_transl, matrix_ks_kp_re, matrix_ks_kp_im, kpoints)

      TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: matrix_ks_transl, matrix_ks_kp_re, &
                                                            matrix_ks_kp_im
      TYPE(kpoint_type), POINTER                         :: kpoints

      CHARACTER(len=*), PARAMETER :: routineN = 'transform_matrix_ks_to_kp'

      INTEGER                                            :: handle, i_kp, i_spin, n_kp, n_spin
      INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: xkp
      TYPE(dbcsr_type), POINTER                          :: mat_im, mat_re
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sab_nl

      CALL timeset(routineN, handle)

      NULLIFY (sab_nl)
      CALL get_kpoint_info(kpoints, nkp=n_kp, xkp=xkp, sab_nl=sab_nl, cell_to_index=cell_to_index)
      CPASSERT(ASSOCIATED(sab_nl))

      n_spin = SIZE(matrix_ks_transl, 1)

      kpoint_loop: DO i_kp = 1, n_kp
         spin_loop: DO i_spin = 1, n_spin

            mat_re => matrix_ks_kp_re(i_spin, i_kp)%matrix
            mat_im => matrix_ks_kp_im(i_spin, i_kp)%matrix

            ! clear both target matrices before handing them to rskp_transform
            CALL dbcsr_set(mat_re, 0.0_dp)
            CALL dbcsr_set(mat_im, 0.0_dp)

            CALL rskp_transform(rmatrix=mat_re, cmatrix=mat_im, rsmat=matrix_ks_transl, &
                                ispin=i_spin, xkp=xkp(1:3, i_kp), &
                                cell_to_index=cell_to_index, sab_nl=sab_nl)

         END DO spin_loop
      END DO kpoint_loop

      CALL timestop(handle)

   END SUBROUTINE transform_matrix_ks_to_kp

! **************************************************************************************************
!> \brief Allocates the real and imaginary k-point Kohn-Sham matrix sets, one matrix per
!>        (spin, k-point). Every matrix is created from the template matrix_ks_transl(1, 1),
!>        gets its block structure from the k-point neighbor list and is set to zero.
!>        The real parts are symmetric and the imaginary parts antisymmetric DBCSR matrices.
!> \param matrix_ks_transl real-space Kohn-Sham matrices; first index is the spin, and
!>        matrix_ks_transl(1, 1) is used as the template
!> \param matrix_ks_kp_re on output: newly allocated real parts, indexed (spin, k-point)
!> \param matrix_ks_kp_im on output: newly allocated imaginary parts, indexed (spin, k-point)
!> \param kpoints k-point environment providing the number of k-points and the neighbor list
!> \note  matrix_ks_kp_re / matrix_ks_kp_im are nullified on entry; any previous allocation is not
!>        released here (presumably callers pass unallocated sets - TODO confirm)
! **************************************************************************************************
   SUBROUTINE allocate_matrix_ks_kp(matrix_ks_transl, matrix_ks_kp_re, matrix_ks_kp_im, kpoints)

      TYPE(dbcsr_p_type), DIMENSION(:, :), POINTER       :: matrix_ks_transl, matrix_ks_kp_re, &
                                                            matrix_ks_kp_im
      TYPE(kpoint_type), POINTER                         :: kpoints

      CHARACTER(len=*), PARAMETER :: routineN = 'allocate_matrix_ks_kp'

      INTEGER                                            :: handle, ikp, ispin, nkp, nspin
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sab_nl

      CALL timeset(routineN, handle)

      ! only the number of k-points and the neighbor list are needed here
      NULLIFY (sab_nl)
      CALL get_kpoint_info(kpoints, nkp=nkp, sab_nl=sab_nl)

      CPASSERT(ASSOCIATED(sab_nl))

      nspin = SIZE(matrix_ks_transl, 1)

      NULLIFY (matrix_ks_kp_re, matrix_ks_kp_im)
      CALL dbcsr_allocate_matrix_set(matrix_ks_kp_re, nspin, nkp)
      CALL dbcsr_allocate_matrix_set(matrix_ks_kp_im, nspin, nkp)

      DO ikp = 1, nkp
         DO ispin = 1, nspin

            ALLOCATE (matrix_ks_kp_re(ispin, ikp)%matrix)
            ALLOCATE (matrix_ks_kp_im(ispin, ikp)%matrix)

            ! real part is symmetric, imaginary part antisymmetric
            CALL dbcsr_create(matrix_ks_kp_re(ispin, ikp)%matrix, &
                              template=matrix_ks_transl(1, 1)%matrix, &
                              matrix_type=dbcsr_type_symmetric)
            CALL dbcsr_create(matrix_ks_kp_im(ispin, ikp)%matrix, &
                              template=matrix_ks_transl(1, 1)%matrix, &
                              matrix_type=dbcsr_type_antisymmetric)

            ! block structure from the k-point neighbor list
            CALL cp_dbcsr_alloc_block_from_nbl(matrix_ks_kp_re(ispin, ikp)%matrix, sab_nl)
            CALL cp_dbcsr_alloc_block_from_nbl(matrix_ks_kp_im(ispin, ikp)%matrix, sab_nl)

            CALL dbcsr_set(matrix_ks_kp_re(ispin, ikp)%matrix, 0.0_dp)
            CALL dbcsr_set(matrix_ks_kp_im(ispin, ikp)%matrix, 0.0_dp)

         END DO
      END DO

      CALL timestop(handle)

   END SUBROUTINE allocate_matrix_ks_kp

END MODULE rpa_gw_sigma_x

